In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

💾📈 1.5 - Визуализация и соревновательный анализ данных
¶

А зачем нужна визуализация в соревновательном анализе данных? 🤔

  • Качественная визуализация хорошо сочетается с генерацией и анализом признаков. Легко понять какие признаки будут полезны.
  • Визуализация может здорово помочь разобраться в том, как устроены данные или как работает ваша модель. Проверить наличие зависимости визуально.
  • В курсе мы будем часто использовать элементы графики для поиска инсайдов в данных.

Импорт библиотек¶

In [2]:
# !pip install --upgrade numpy pandas seaborn -q

import numpy as np
import pandas as pd

import warnings
warnings.simplefilter("ignore")

pd.set_option('display.max_columns', None)

Импорт данных¶

Обогатим данные по поездкам

In [3]:
rides_info = pd.read_csv("https://raw.githubusercontent.com/a-milenkin/Competitive_Data_Science/main/data/rides_info.csv")

print(rides_info.shape)

rides_info.head()
(739500, 14)
Out[3]:
user_id car_id ride_id ride_date rating ride_duration ride_cost speed_avg speed_max stop_times distance refueling user_ride_quality deviation_normal
0 o52317055h A-1049127W b1v 2020-01-01 4.95 21 268 36 113.548538 0 514.246920 0 1.115260 2.909
1 H41298704y A-1049127W T1U 2020-01-01 6.91 8 59 36 93.000000 1 197.520662 0 1.650465 4.133
2 v88009926E A-1049127W g1p 2020-01-02 6.01 20 315 61 81.959675 0 1276.328206 0 2.599112 2.461
3 t14229455i A-1049127W S1c 2020-01-02 0.26 19 205 32 128.000000 0 535.680831 0 3.216255 0.909
4 W17067612E A-1049127W X1b 2020-01-03 1.21 56 554 38 90.000000 1 1729.143367 0 2.716550 -1.822
In [4]:
cars = pd.read_csv("https://raw.githubusercontent.com/a-milenkin/Competitive_Data_Science/main/data/car_train.csv")

print(cars.shape)

cars.head()
(2337, 10)
Out[4]:
car_id model car_type fuel_type car_rating year_to_start riders year_to_work target_reg target_class
0 y13744087j Kia Rio X-line economy petrol 3.78 2015 76163 2021 108.53 another_bug
1 O41613818T VW Polo VI economy petrol 3.90 2015 78218 2021 35.20 electro_bug
2 d-2109686j Renault Sandero standart petrol 6.30 2012 23340 2017 38.62 gear_stick
3 u29695600e Mercedes-Benz GLC business petrol 4.04 2011 1263 2020 30.34 engine_fuel
4 N-8915870N Renault Sandero standart petrol 4.70 2012 26428 2017 30.45 engine_fuel
In [5]:
driver_info = pd.read_csv("https://raw.githubusercontent.com/a-milenkin/Competitive_Data_Science/main/data/driver_info.csv")

print(driver_info.shape)

driver_info.head()
(15153, 7)
Out[5]:
age user_rating user_rides user_time_accident user_id sex first_ride_date
0 27 9.0 865 19.0 l17437965W 1 2019-4-2
1 46 7.9 2116 11.0 Z12362316j 0 2021-11-19
2 59 7.8 947 4.0 g11098715c 0 2021-1-15
3 37 7.0 18 4.0 U12618125q 0 2019-11-20
4 39 8.2 428 7.0 A14375829B 0 2019-7-23
In [6]:
# Загружаем уже знакомый нам датасет
rides_info = pd.read_csv("https://raw.githubusercontent.com/a-milenkin/Competitive_Data_Science/main/data/rides_info.csv")
cars = pd.read_csv("https://raw.githubusercontent.com/a-milenkin/Competitive_Data_Science/main/data/car_train.csv")
driver_info = pd.read_csv("https://raw.githubusercontent.com/a-milenkin/Competitive_Data_Science/main/data/driver_info.csv")

rides_info = rides_info.merge(cars, on="car_id", how="right")
print(rides_info.shape)
rides_info = rides_info.merge(driver_info, on="user_id", how="left")
print(rides_info.shape)
(406638, 23)
(406638, 29)

Встроенные элементы визуализации в Pandas
¶

Для визуализации данных из датасета необязательно использовать сторонние фреймворки,
в Pandas есть несколько встроенных функций, которыми достаточно просто пользоваться.

In [7]:
print(rides_info.shape)

rides_info.head(3)
(406638, 29)
Out[7]:
user_id car_id ride_id ride_date rating ride_duration ride_cost speed_avg speed_max stop_times distance refueling user_ride_quality deviation_normal model car_type fuel_type car_rating year_to_start riders year_to_work target_reg target_class age user_rating user_rides user_time_accident sex first_ride_date
0 n14703870u y13744087j Q1Z 2020-01-01 5.72 220 3514 42 NaN 6 1.682556e+03 0 0.524750 0.0 Kia Rio X-line economy petrol 3.78 2015 76163 2021 108.53 another_bug 38 7.4 268 2.0 0 2019-9-7
1 W18144322F y13744087j M1P 2020-01-01 2.52 37392 523483 45 53.0 2 1.711379e+06 0 1.723151 0.0 Kia Rio X-line economy petrol 3.78 2015 76163 2021 108.53 another_bug 46 6.7 643 3.0 0 2020-7-28
2 Q11878237R y13744087j D1j 2020-01-02 7.17 45 444 54 82.0 0 9.523155e+02 0 0.876440 -0.0 Kia Rio X-line economy petrol 3.78 2015 76163 2021 108.53 another_bug 49 8.4 161 NaN 0 2020-3-23

Смотрим на распределение через `df.hist( )`
¶

df.hist( ) - строит гистограмму распределения данных по числовым столбцам

In [8]:
rides_info.hist();  # Вызываем функцию hist( )

Видно, что графики получились мелкие и текстовая информация накладывается друг на друга, также столбец пол определился как числовой, хотя является категориальным. Исправим это, добавив параметр figsize и отсеим ненужный столбец. Так же, чтобы не выводились служебные строки перед графиками можно добавить None или ; в конце ячейки.

In [9]:
# Добавили ; в конце, чтобы не выводилась служебная информация
rides_info.drop("sex", axis=1).hist(figsize=(20, 15), layout=(-1, 5));

Распределение через боксплот `df.boxplot( )`
¶

df.boxplot( ) - функция, позволяющая рисовать, так называемые, "ящики с усами", показывающие среднее значение, стандартные отклонения и разброс признака на одном графике. Давайте отобразим вид поломки, в зависимоти от рейтинга водителя, также поменяем параметр fontsize, отвечающий за размер шрифта.

In [10]:
rides_info.boxplot(
    column=["user_rating"], by="target_class", fontsize=8, figsize=(20, 5)
);

Тепловые карты зависимостей численных и категориальных переменных
¶

In [13]:
corr = rides_info.corr(numeric_only=True).round(2)
corr.style.background_gradient(cmap="RdYlGn")
Out[13]:
  rating ride_duration ride_cost speed_avg speed_max stop_times distance refueling user_ride_quality deviation_normal car_rating year_to_start riders year_to_work target_reg age user_rating user_rides user_time_accident sex
rating 1.000000 0.000000 0.000000 -0.090000 -0.230000 -0.060000 -0.000000 0.000000 -0.000000 -0.040000 0.000000 0.000000 0.000000 0.000000 -0.000000 0.000000 0.000000 -0.000000 0.000000 -0.000000
ride_duration 0.000000 1.000000 0.920000 -0.000000 0.000000 0.010000 0.960000 0.000000 0.000000 -0.000000 0.000000 0.000000 0.000000 -0.000000 -0.000000 -0.000000 -0.000000 -0.000000 -0.000000 0.000000
ride_cost 0.000000 0.920000 1.000000 -0.000000 -0.000000 0.010000 0.880000 0.000000 0.000000 0.000000 -0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.010000 0.000000 -0.000000 -0.000000
speed_avg -0.090000 -0.000000 -0.000000 1.000000 0.500000 -0.130000 0.060000 0.000000 -0.010000 0.040000 -0.000000 -0.000000 -0.000000 0.000000 -0.020000 0.000000 -0.000000 -0.000000 -0.020000 -0.000000
speed_max -0.230000 0.000000 -0.000000 0.500000 1.000000 0.050000 0.030000 -0.000000 -0.010000 0.090000 0.000000 -0.010000 -0.010000 -0.000000 0.150000 -0.000000 0.010000 -0.000000 0.000000 0.000000
stop_times -0.060000 0.010000 0.010000 -0.130000 0.050000 1.000000 -0.010000 -0.000000 0.020000 0.000000 0.020000 0.010000 0.010000 0.000000 0.190000 -0.000000 0.030000 -0.000000 0.100000 0.010000
distance -0.000000 0.960000 0.880000 0.060000 0.030000 -0.010000 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 -0.010000 -0.000000 -0.000000 -0.000000 -0.000000 0.000000
refueling 0.000000 0.000000 0.000000 0.000000 -0.000000 -0.000000 0.000000 1.000000 0.000000 -0.000000 -0.000000 0.000000 0.000000 -0.000000 0.000000 0.000000 -0.000000 -0.000000 -0.000000 -0.000000
user_ride_quality -0.000000 0.000000 0.000000 -0.010000 -0.010000 0.020000 0.000000 0.000000 1.000000 -0.010000 0.020000 -0.020000 -0.020000 0.020000 0.030000 -0.000000 -0.000000 0.000000 0.010000 0.000000
deviation_normal -0.040000 -0.000000 0.000000 0.040000 0.090000 0.000000 0.000000 -0.000000 -0.010000 1.000000 0.020000 0.010000 0.010000 -0.000000 0.100000 0.000000 0.030000 0.000000 0.050000 -0.000000
car_rating 0.000000 0.000000 -0.000000 -0.000000 0.000000 0.020000 0.000000 -0.000000 0.020000 0.020000 1.000000 -0.020000 -0.010000 -0.020000 0.020000 0.000000 0.000000 -0.000000 0.010000 -0.000000
year_to_start 0.000000 0.000000 0.000000 -0.000000 -0.010000 0.010000 0.000000 0.000000 -0.020000 0.010000 -0.020000 1.000000 0.990000 0.060000 0.010000 0.020000 0.020000 0.010000 0.000000 -0.010000
riders 0.000000 0.000000 0.000000 -0.000000 -0.010000 0.010000 0.000000 0.000000 -0.020000 0.010000 -0.010000 0.990000 1.000000 0.050000 0.000000 0.020000 0.010000 0.010000 0.000000 -0.010000
year_to_work 0.000000 -0.000000 0.000000 0.000000 -0.000000 0.000000 0.000000 -0.000000 0.020000 -0.000000 -0.020000 0.060000 0.050000 1.000000 0.030000 0.010000 0.010000 0.010000 0.000000 -0.010000
target_reg -0.000000 -0.000000 0.010000 -0.020000 0.150000 0.190000 -0.010000 0.000000 0.030000 0.100000 0.020000 0.010000 0.000000 0.030000 1.000000 0.010000 0.090000 0.010000 0.120000 -0.000000
age 0.000000 -0.000000 0.000000 0.000000 -0.000000 -0.000000 -0.000000 0.000000 -0.000000 0.000000 0.000000 0.020000 0.020000 0.010000 0.010000 1.000000 -0.110000 -0.020000 -0.130000 -0.760000
user_rating 0.000000 -0.000000 0.010000 -0.000000 0.010000 0.030000 -0.000000 -0.000000 -0.000000 0.030000 0.000000 0.020000 0.010000 0.010000 0.090000 -0.110000 1.000000 0.000000 0.250000 0.150000
user_rides -0.000000 -0.000000 0.000000 -0.000000 -0.000000 -0.000000 -0.000000 -0.000000 0.000000 0.000000 -0.000000 0.010000 0.010000 0.010000 0.010000 -0.020000 0.000000 1.000000 -0.010000 0.020000
user_time_accident 0.000000 -0.000000 -0.000000 -0.020000 0.000000 0.100000 -0.000000 -0.000000 0.010000 0.050000 0.010000 0.000000 0.000000 0.000000 0.120000 -0.130000 0.250000 -0.010000 1.000000 0.170000
sex -0.000000 0.000000 -0.000000 -0.000000 0.000000 0.010000 0.000000 -0.000000 0.000000 -0.000000 -0.000000 -0.010000 -0.010000 -0.010000 -0.000000 -0.760000 0.150000 0.020000 0.170000 1.000000
In [15]:
pd.crosstab(
    rides_info["target_class"],
    rides_info["model"],
    #margins = True,
    normalize = True, 
).style.background_gradient(cmap="RdYlGn")
Out[15]:
model Audi A3 Audi A4 Audi Q3 BMW 320i Fiat 500 Hyundai Solaris Kia Rio Kia Rio X Kia Rio X-line Kia Sportage MINI CooperSE Mercedes-Benz E200 Mercedes-Benz GLC Mini Cooper Nissan Qashqai Renault Kaptur Renault Sandero Skoda Rapid Smart Coupe Smart ForFour Smart ForTwo Tesla Model 3 VW Polo VW Polo VI VW Tiguan Volkswagen ID.4
target_class                                                    
another_bug 0.000856 0.001284 0.001284 0.000856 0.000000 0.010697 0.006846 0.008130 0.007702 0.009414 0.002567 0.000428 0.001284 0.000428 0.006846 0.006846 0.009414 0.006846 0.004707 0.008986 0.002995 0.000856 0.003851 0.008986 0.006846 0.000428
break_bug 0.001284 0.001712 0.000000 0.000428 0.001284 0.005563 0.005991 0.005563 0.005135 0.008986 0.001712 0.001284 0.000856 0.000428 0.004279 0.008558 0.008130 0.010270 0.006846 0.007702 0.006846 0.000428 0.008558 0.005563 0.007702 0.000428
electro_bug 0.001284 0.000428 0.002567 0.000428 0.000428 0.005563 0.005135 0.005991 0.007274 0.005563 0.000000 0.001284 0.002139 0.000856 0.007274 0.005991 0.005135 0.007274 0.005563 0.005563 0.008130 0.000856 0.006418 0.007702 0.005991 0.001712
engine_check 0.000856 0.001284 0.000428 0.002139 0.001284 0.007274 0.004707 0.006846 0.005563 0.007702 0.001284 0.002567 0.000000 0.001712 0.005991 0.009414 0.005135 0.008558 0.005135 0.006846 0.006418 0.000856 0.009842 0.007702 0.005991 0.000000
engine_fuel 0.000856 0.000856 0.001712 0.000428 0.000000 0.009414 0.004707 0.004707 0.007702 0.008130 0.000000 0.000856 0.000856 0.002139 0.008130 0.009414 0.009414 0.005563 0.005135 0.007702 0.005991 0.001284 0.007702 0.002995 0.005563 0.000856
engine_ignition 0.000856 0.000856 0.000428 0.001284 0.000428 0.007274 0.005135 0.008558 0.005991 0.006418 0.000856 0.000000 0.000000 0.001284 0.012837 0.004707 0.009842 0.006846 0.007274 0.008558 0.005991 0.000856 0.005135 0.006418 0.006846 0.000428
engine_overheat 0.000428 0.001284 0.001712 0.000428 0.002139 0.008558 0.004707 0.006418 0.006846 0.008130 0.001712 0.002139 0.001284 0.002139 0.009842 0.010270 0.010270 0.004279 0.004707 0.007274 0.006846 0.000000 0.006846 0.007274 0.006846 0.001284
gear_stick 0.000856 0.000856 0.000856 0.000856 0.002139 0.007702 0.005563 0.006418 0.007274 0.008558 0.000856 0.001284 0.000856 0.000428 0.007274 0.010697 0.007702 0.007274 0.006418 0.008986 0.007702 0.000856 0.006418 0.008558 0.004707 0.000428
wheel_shake 0.000000 0.000000 0.000000 0.000000 0.000000 0.006846 0.004707 0.008558 0.006418 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.005991 0.004707 0.005991 0.004707 0.000000 0.005563 0.009842 0.007274 0.000000

И все это не выходя из Pandas !¶

df.plot( ) - функция предоставляющая весь функционал библиотеки Matplotlib, так же можно через точку указывать конкретный тип желаемого графика, например: df.plot.bar( ) или df.plot.scatter( ). На наш взгляд, пользоваться ей не очень удобно, поэтому не будем на ней останавливаться, подробнее можно ознакомиться в документации. Сразу перейдем к более продвинутым инструментам!

seaborn - это все что вам надо!
¶

In [16]:
import seaborn as sns

#sns.set_theme()  # Тут можно задать стили

`sns.scatterplot()` + `sns.lineplot()` + 💪 = `sns.relplot`(`kind = ...`)
¶

Функция relplot() объединяет в себе функционал scatterplot() (диаграмма рассеяния) и lineplot(), смена происходит переключателем kind.
Таким образом можно держать в голове меньше разных названий функций и параметров.

Более того, plt.figure(figsize=(20, 5)) больше не нужен, ведь есть параметры height и aspect.

Диаграмма рассеяния — это диаграмма, которая отображает точки на основе двух измерений набора данных.

In [51]:
g = sns.relplot(
    data=rides_info,
    x="ride_date",
    y="deviation_normal",
    hue="target_class",
    kind="line",  # или scatter
    aspect=4, 
    #height=30
)

g.set_xticklabels(rotation=45, horizontalalignment="right", step=2);

По оси х мы вытянули время. По оси y исследуемый столбец, в данном случае deviation_normal. Все точки окрасили в соответствующий класс для наглядности.

Становится видно, что величина deviation_normal ведет себя по-разному в зависимости от класса будущей поломки. Например, точку перегиба в центре графика для класса engine_ignition можно использовать как вспомогательный признак.

In [15]:
## Импортируем вспомогтаельные фрагменты из matplolib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure(figsize=(20, 5))

g = sns.lineplot(
    data=rides_info,
    x="ride_date",
    y="deviation_normal",
    hue="target_class",
)

g.set_xticklabels(g.get_xticklabels(), rotation=45, horizontalalignment="right")
g.xaxis.set_major_locator(ticker.MultipleLocator(2));
In [16]:
g = sns.relplot(
    data=rides_info,
    x="ride_date",
    y="user_ride_quality",
    kind="line",
    hue="target_class",
    aspect=4,
)

g.set_xticklabels(rotation=45, horizontalalignment="right", step=2);

Аналогично растянули признак user_ride_quality во времени по оси y как исследуемый столбец. Все точки мы окрасили в соответствующий класс для наглядности.

Заметно, что величина user_ride_quality ведет себя по-разному в зависимости от класса будущей поломки. Например, стартовая точка у разных классов находится на разной высоте.

In [53]:
# Отберем информацию только про 10 машин
tmp = rides_info[rides_info["car_id"].isin(rides_info.car_id.unique()[:10])]
tmp.shape
Out[53]:
(1740, 29)
In [58]:
# Как ведет себя deviation_normal во времени для 10 автомобилей

g = sns.relplot(
    data=tmp,

    x="ride_date",
    y="deviation_normal",

    hue="target_class",
    style="car_id",

    legend=True,
    kind="line",
    aspect=4,
)
g.set_xticklabels(rotation=45, horizontalalignment="right", step=2);

Какие признаки могут помочь различить классы поломок?¶

  • Замечаем точки перегиба
  • Замечаем точки входа
  • Возможно углы наклона до перегиба

Все это можно можно будет использовать, как доп. признаки!¶

In [59]:
# Вместо известного нам:

plt.figure(figsize=(20, 5))
sns.scatterplot(data = rides_info,
                 x = 'ride_duration',
                 y = 'ride_cost',
                 hue = 'model'
                );
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[59], line 3
      1 # Вместо известного нам:
----> 3 plt.figure(figsize=(20, 5))
      4 sns.scatterplot(data = rides_info,
      5                  x = 'ride_duration',
      6                  y = 'ride_cost',
      7                  hue = 'model'
      8                 )

NameError: name 'plt' is not defined
In [20]:
g = sns.relplot(
    data=rides_info,
    x="ride_duration",
    y="ride_cost",
    hue="target_class",

    kind="scatter",
    aspect=4,
    alpha=0.5,
);

С помощью более функционального sns.relplot() пишем меньше кода и быстрее анализируем. Кстати, по графику заметно, что угол наклона у разных машин разный.

Возможно feature = ride_cost/ride_duration - это отличный доп.признак!

Имеет смысл поиграть с такими параметриами как row и col, чтобы построить этот график в разрезах по другим категориальным признакам

In [21]:
g = sns.relplot(
    data=rides_info,
    x="target_class",

    y="user_time_accident",

    hue="car_type",
    kind="scatter",
    aspect=4,
    alpha=0.5,
);

Запоминайте, полезно растягивать исследуемые признаки не только вдоль времени или численных признаков, но еще и по категориальным переменным.

В этом примере по оси X таргет, а цветом окрашен сегмент машины.

In [22]:
sns.relplot(
    data=rides_info,
    x="rating",
    y="user_time_accident",
    hue="target_class",
    kind="scatter",
    aspect=4,
    alpha=0.5,
);

Еще один интересный инсайт того, как число инцидентов водителя зависит от его рейтинга.

Легко заметить, что некоторые классы поломок свойственны поездкам с низким рейтингом.

🧠 Очень рекомендуем применять метод sns.relplot().

Чтобы эффективнее исследовать данные:¶

  • 🔴 Растягивайте исследуемый признак во времени
  • 🟢 Стройте признак относительно другого признака / частоты
  • 🟡 Раскрашивайте точки в категории / таргет (параметры: hue / style / size)

🔥 Построение распределений и sns.displot() 🔥
¶

`sns.displot`(`kind=...`) = `sns.histplot()`  + `kdeplot()` + `ecdfplot()` + 💪

В Seaborn функция displot() объединяет в себе функционал histplot(), kdeplot() и ecdfplot() переключение происходит с помощью kind.
Получается, что снова можно держать в голове меньше разных названий функций и параметров. И снова plt.figure(figsize=(20, 5)) не нужен, ведь есть параметры height и aspect.

In [23]:
# plt.figure(figsize=(20, 5)) # уже можно не использовать, ведь есть параметр aspect
g = sns.displot(
    data=rides_info,

    x="user_ride_quality",
    # y="user_time_accident",

    hue="target_class",
    legend=True,
    aspect=4,
    kind="hist",  # kde
    alpha=0.5,
);

🧠 🧠 Многофункциональный метод sns.displot() - это вторая вещь, после sns.relplot() которую вам нужно запомнить и использовать при возможности!

In [24]:
tmp = rides_info[rides_info["car_id"].isin(rides_info.sample(100, random_state=10).car_id.unique()[:20])]

tmp.car_id.nunique()
Out[24]:
20
In [25]:
g = sns.displot(
    data=tmp,

    x="user_ride_quality",
    y="deviation_normal",

    aspect=1,
    #kind="kde",
    alpha=0.5,

    hue="car_id",
    col="target_class",

    col_wrap=4,

).set_xticklabels(rotation=45, horizontalalignment="right");

🔥 Категориальные данные и `sns.catplot()`
¶

В Seaborn функция catplot() объединяет в себе функционал boxplot(), violinplot(), boxenplot, stripplot, swarmplot, а еще pointplot, barplot и countplot. Переключение происходит параметром kind.
Получается, что снова можно держать в голове меньше разных названий функций и параметров.

In [26]:
sns.catplot(
    data=rides_info,
    x="target_class",
    y="distance",

    aspect=4,
    hue="car_type",
    alpha=0.25,

).set_xticklabels(rotation=45, horizontalalignment="right");
In [27]:
# Варируем в параметр kind, получаем любой вид графики

sns.catplot(
    data=rides_info,
    x="target_class",
    y="user_ride_quality",
    hue="car_type",
    aspect=4,
    kind="boxen",
).set_xticklabels(rotation=45, horizontalalignment="right");

Вывод: Сходу замечаем, что поломки класса wheel_shake не происходят в трех сегментах машин.

🧠 🧠 🧠 Многофункциональный метод sns.catplot() это третья вещь, после sns.displot() и sns.relplot() которую вам нужно запомнить и использовать при возможности!

🔥 Попарные корреляции и тепловая карта - `sns.heatmap()`
¶

Функция sns.heatmap() чаще всего используется для отрисовки наглядной матрицы корреляций признаков. Хотя можно её использовать, когда хочется подсветить какую угодно таблицу значений.

Посмотрим как её построить на нашем датасете. Сначала посмотрим матрицу корреляций, котрорую выводит pandas:

In [28]:
rides_info.corr(numeric_only=True)
Out[28]:
rating ride_duration ride_cost speed_avg speed_max stop_times distance refueling user_ride_quality deviation_normal car_rating year_to_start riders year_to_work target_reg age user_rating user_rides user_time_accident sex
rating 1.000000 0.001227 0.001568 -0.086294 -0.234631 -0.055198 -0.003871 0.001678 -0.000310 -0.042613 0.003083 0.000846 0.000837 0.000607 -0.003452 0.002053 0.003308 -0.002442 0.000104 -0.004693
ride_duration 0.001227 1.000000 0.915057 -0.001735 0.000430 0.009862 0.964560 0.004879 0.000633 -0.000385 0.000894 0.000103 0.000352 -0.000472 -0.002801 -0.002866 -0.000477 -0.000353 -0.001229 0.001111
ride_cost 0.001568 0.915057 1.000000 -0.001216 -0.000034 0.007818 0.882981 0.004250 0.001506 0.000323 -0.000041 0.002257 0.002378 0.004566 0.007067 0.000046 0.005591 0.001068 -0.000278 -0.000168
speed_avg -0.086294 -0.001735 -0.001216 1.000000 0.496602 -0.129687 0.064636 0.000575 -0.007180 0.040983 -0.004753 -0.001666 -0.002308 0.000170 -0.024530 0.001540 -0.002659 -0.000611 -0.022754 -0.000349
speed_max -0.234631 0.000430 -0.000034 0.496602 1.000000 0.051144 0.032024 -0.000847 -0.008618 0.087402 0.000648 -0.012698 -0.014007 -0.003338 0.149425 -0.001196 0.006604 -0.001281 0.004410 0.002147
stop_times -0.055198 0.009862 0.007818 -0.129687 0.051144 1.000000 -0.006981 -0.000387 0.020246 0.004780 0.018989 0.008198 0.009124 0.002615 0.193921 -0.003088 0.032055 -0.001420 0.098434 0.005554
distance -0.003871 0.964560 0.882981 0.064636 0.032024 -0.006981 1.000000 0.003681 0.000124 0.002246 0.000241 0.000411 0.000757 0.000053 -0.005720 -0.002923 -0.000643 -0.000154 -0.003351 0.001479
refueling 0.001678 0.004879 0.004250 0.000575 -0.000847 -0.000387 0.003681 1.000000 0.001204 -0.001345 -0.002619 0.002083 0.001635 -0.003023 0.000861 0.001442 -0.000931 -0.001680 -0.001183 -0.002582
user_ride_quality -0.000310 0.000633 0.001506 -0.007180 -0.008618 0.020246 0.000124 0.001204 1.000000 -0.011180 0.018024 -0.019891 -0.020917 0.021576 0.025973 -0.003643 -0.002574 0.001516 0.005087 0.003929
deviation_normal -0.042613 -0.000385 0.000323 0.040983 0.087402 0.004780 0.002246 -0.001345 -0.011180 1.000000 0.021623 0.009543 0.012239 -0.004786 0.100291 0.000671 0.027107 0.000613 0.054867 -0.000229
car_rating 0.003083 0.000894 -0.000041 -0.004753 0.000648 0.018989 0.000241 -0.002619 0.018024 0.021623 1.000000 -0.016672 -0.013615 -0.017544 0.024515 0.002611 0.003010 -0.000529 0.013035 -0.001380
year_to_start 0.000846 0.000103 0.002257 -0.001666 -0.012698 0.008198 0.000411 0.002083 -0.019891 0.009543 -0.016672 1.000000 0.986746 0.059072 0.006390 0.016189 0.015626 0.009642 0.002742 -0.009968
riders 0.000837 0.000352 0.002378 -0.002308 -0.014007 0.009124 0.000757 0.001635 -0.020917 0.012239 -0.013615 0.986746 1.000000 0.048270 0.003259 0.015917 0.014912 0.009540 0.003498 -0.009730
year_to_work 0.000607 -0.000472 0.004566 0.000170 -0.003338 0.002615 0.000053 -0.003023 0.021576 -0.004786 -0.017544 0.059072 0.048270 1.000000 0.034604 0.011401 0.014833 0.013002 0.003461 -0.006012
target_reg -0.003452 -0.002801 0.007067 -0.024530 0.149425 0.193921 -0.005720 0.000861 0.025973 0.100291 0.024515 0.006390 0.003259 0.034604 1.000000 0.006725 0.088232 0.008236 0.122316 -0.001264
age 0.002053 -0.002866 0.000046 0.001540 -0.001196 -0.003088 -0.002923 0.001442 -0.003643 0.000671 0.002611 0.016189 0.015917 0.011401 0.006725 1.000000 -0.109918 -0.019438 -0.130837 -0.757736
user_rating 0.003308 -0.000477 0.005591 -0.002659 0.006604 0.032055 -0.000643 -0.000931 -0.002574 0.027107 0.003010 0.015626 0.014912 0.014833 0.088232 -0.109918 1.000000 0.003577 0.252693 0.151911
user_rides -0.002442 -0.000353 0.001068 -0.000611 -0.001281 -0.001420 -0.000154 -0.001680 0.001516 0.000613 -0.000529 0.009642 0.009540 0.013002 0.008236 -0.019438 0.003577 1.000000 -0.005332 0.018100
user_time_accident 0.000104 -0.001229 -0.000278 -0.022754 0.004410 0.098434 -0.003351 -0.001183 0.005087 0.054867 0.013035 0.002742 0.003498 0.003461 0.122316 -0.130837 0.252693 -0.005332 1.000000 0.172690
sex -0.004693 0.001111 -0.000168 -0.000349 0.002147 0.005554 0.001479 -0.002582 0.003929 -0.000229 -0.001380 -0.009968 -0.009730 -0.006012 -0.001264 -0.757736 0.151911 0.018100 0.172690 1.000000

Можно ориентироваться, но достаточно сложно при большом количестве признаков.
Теперь отрисуем матрицу с помощью sns.heatmap( ):

In [29]:
# Размер будущей тепловой карты. Можно указать в коде разово
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (12, 12)
In [30]:
sns.heatmap(
    data=rides_info.corr(numeric_only=True).round(2),
    square=True,
    annot=True,
);

Добавим красоты¶

In [31]:
heatmap = sns.heatmap(
    rides_info.corr(numeric_only=True).round(2),
    annot=True,
    square=True,

    cmap="Blues",  # использовать синию цветовую карту
    cbar_kws={"fraction": 0.01},  # боковой колор-бар (shrink colour bar)
    linewidth=2,  # пространство между клетками
)

heatmap.set_xticklabels(
    heatmap.get_xticklabels(), rotation=45, horizontalalignment="right"
);

Видно, что части снизу и сверху от диагонали идентичны, и новой инфомации не несут, создавая визуальный шум - уберем одну из них

In [32]:
#sns.set_style("whitegrid")

# Воспользуемся функциями np.triu, чтобы изолировать верхний треугольник (np.tril нижний)
# функция np.ones_like() изменит все изолированные значения на 1.
mask = np.triu(np.ones_like(rides_info.corr(numeric_only=True), dtype=bool))

heatmap = sns.heatmap(
    rides_info.corr(numeric_only=True).round(2),
    annot=True,
    square=True,
    cmap="BrBG",
    cbar_kws={"fraction": 0.01},
    linewidth=1,

    mask=mask,
)

heatmap.set_title(
    "Треугольная тепловая карта корреляции", fontdict={"fontsize": 18}, pad=16
);

Связка `pd.crosstab` + `sns.heatmap` = 🔥
¶

In [33]:
# Размер будущей тепловой карты. Можно указать в коде разово
plt.rcParams["figure.figsize"] = (25, 25)
In [34]:
crst = pd.crosstab(
    rides_info["target_class"],
    rides_info["model"],
    normalize=True,
).round(4)


heatmap = sns.heatmap(
    crst,
    annot=True,
    square=True,
    cmap="BrBG",
    cbar_kws={"fraction": 0.01},
    linewidth=1,
)

heatmap.set_title("Тепловая карта совстречаемости", fontdict={"fontsize": 18}, pad=16);

🔥 Парные зависимости между переменными и `sns.pairplot()`
¶

In [35]:
sns.pairplot(
    rides_info.sample(1000),
    vars=["speed_max", "distance", "speed_avg"],
    corner=True,
    hue="target_class",
)
sns.despine();

`sns.lmplot` = `sns.pairplot` + умение провести прямую
¶

In [36]:
sns.lmplot(
    data=rides_info.sample(1000),
    x="speed_avg",
    y="speed_max",

    col="target_class",

    col_wrap=3,
);
In [37]:
sns.lmplot(
    data=rides_info,
    x="ride_duration",
    y="ride_cost",
    col="target_class",

    hue="car_type",

    col_wrap=5,
);

Композиция нескольких видов графиков и `sns.jointplot()`
¶

In [38]:
# Диаграмма рассеяния + распределения

sns.jointplot(
    data=rides_info.sample(10_000),
    x="target_reg",
    y="rating",
    hue="target_class",
    height=10,
    legend=False,
    ratio=2,
    kind="scatter",
);

Карта графики Seaborn¶

Итоги и выводы:¶

  • Способов визуализации много. Иногда можно даже не выходить из Pandas
  • Seaborn очень функциональный и понятный фреймворк. Для соревновательноо DS этого достаточно.
  • Запомните хотя бы три самых важных для себя функции и пользуйтесь

например: sns.relplot / sns.displot / sns.heatmap

  • Растягивайте признак по времени, по частоте или другому признаку.
  • Раскрашивайте в таргет!

Спасибо за внимание!
¶

In [39]:
cars
Out[39]:
car_id model car_type fuel_type car_rating year_to_start riders year_to_work target_reg target_class
0 y13744087j Kia Rio X-line economy petrol 3.78 2015 76163 2021 108.53 another_bug
1 O41613818T VW Polo VI economy petrol 3.90 2015 78218 2021 35.20 electro_bug
2 d-2109686j Renault Sandero standart petrol 6.30 2012 23340 2017 38.62 gear_stick
3 u29695600e Mercedes-Benz GLC business petrol 4.04 2011 1263 2020 30.34 engine_fuel
4 N-8915870N Renault Sandero standart petrol 4.70 2012 26428 2017 30.45 engine_fuel
... ... ... ... ... ... ... ... ... ... ...
2332 j21246192N Smart ForFour economy petrol 4.38 2017 121239 2018 25.48 wheel_shake
2333 h-1554287F Audi A4 premium petrol 4.30 2016 107793 2020 69.26 engine_check
2334 A15262612g Kia Rio economy petrol 3.88 2015 80234 2019 46.03 gear_stick
2335 W-2514493U Renault Sandero standart petrol 4.50 2014 60048 2020 77.19 another_bug
2336 z-1337463D VW Polo economy petrol 3.94 2015 92312 2016 54.68 engine_check

2337 rows × 10 columns

In [40]:
sns.catplot(
    data=cars,
    x='car_type',
    aspect=4, 
    kind="count"
    )

sns.despine();
In [41]:
sns.catplot(
    data=cars,
    x='target_class',
    aspect=4, 
    kind="count"
    )

sns.despine();
In [42]:
cars.model.unique()

sns.catplot(
    data=cars.query("model == 'MINI CooperSE'"),
    x='target_class',
    aspect=4, 
    kind="count"
    )

sns.despine();
In [43]:
cars.model.unique()

sns.catplot(
    data=cars.query("target_class == 'wheel_shake'"),
    x='car_type',
    aspect=4, 
    kind="count"
    )

sns.despine();
In [44]:
sns.catplot(
    data=cars.query("model == 'Nissan Qashqai'"),
    x='target_class',
    aspect=4, 
    kind="count"
    )

sns.despine();
In [45]:
sns.catplot(data=rides_info,
            x='target_class',
            y='distance',
            hue='car_type',
            aspect=4,
            alpha=0.5,
           );
In [46]:
sns.relplot(data=rides_info.head(10),
            kind='scatter',
            x='target_class',
            y='distance',
            aspect=4, 
            alpha=0.5, 
            hue='car_type',
           );
In [47]:
rides_info.fuel_type.unique()
Out[47]:
array(['petrol', 'electro'], dtype=object)
In [49]:
import matplotlib.pyplot as plt

params = {'data' : rides_info,
          'kind' : 'scatter',
          'x' : 'ride_date',
          'y' : 'user_time_accident',
          'row' : 'fuel_type',
          'hue' : 'target_class',
          'size': 'car_type',
          'aspect' : 4, 
          'alpha' : 0.5}

g = sns.relplot(**params)
plt.xticks(rotation=45);
In [ ]: